import copy
import json
import os
import pickle
from math import floor

import numpy as np
import torch


# ara= torch.randn(10)*4
# print(ara)
# ret= map_to_exact_discrete(ara)
# print(ret)
from matplotlib import pyplot as plt
from numpy import uint8

from Causal_Partial_Mnist.Find_CF_Synthetic_Distribution_Mnist import get_intv_dist
from Causal_Partial_Mnist.RejectionSampling_Optimized import rejection_sampling_optimized
from Causal_Partial_Mnist.True_Counterfactuals_Mnist import get_cf_dist
from ModularUtils.ControllerConstants import map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.FunctionsConstant import getdoKey, asKey
from ModularUtils.FunctionsDistribution import calculate_TVD, match_with_true_dist, get_joint_distributions_from_samples


def check_query(Exp):
    """Print two sanity-check TVD values for one hard-coded query on ``Ycolor``.

    1. The distribution returned by ``get_cf_dist`` is compared against the one
       returned by ``get_intv_dist`` for the same intervention.
    2. Labels are generated twice -- once from the posterior obtained through
       rejection sampling on the evidence, once with empty posterior inputs --
       and the TVD between the two sample distributions is printed as
       "cf vs intv".

    The function reads the module-level ``label_generators`` and returns ``None``.
    """
    feature_name = "feature"
    observed = ["Ycolor"]
    intervention = {"X1": 1, "X2": 9}
    evidence = {"X1p": 0, "X2p": 0}

    # --- loaded distributions: counterfactual vs. interventional -------------
    reference_cf = get_cf_dist(Exp, observed, intervention, evidence, "testing_cf", load_dist=True)
    do_key = getdoKey(observed, dict(intervention))  # name under which the SCM result is saved
    reference_intv = get_intv_dist(Exp, observed, dict(intervention), do_key)
    print(calculate_TVD(reference_cf, reference_intv, doPrint=False))

    # --- posterior samples for every evidence of the first CF query ----------
    query = Exp.cf_queries[0]
    all_evidence = list(query["evidence"])
    labels_by_ev, latents_by_ev, noise_by_ev = rejection_sampling_optimized(
        Exp,
        label_generators,
        Exp.Synthetic_Sample_Size,
        all_evidence,
        max_rejections=0,
        warn=100,
    )

    evidence_key = asKey(evidence)
    post_labels = labels_by_ev[evidence_key]
    post_latents = latents_by_ev[evidence_key]
    post_noise = noise_by_ev[evidence_key]

    # --- counterfactual samples from the generators --------------------------
    cf_label_dict = get_generated_labels(
        Exp, label_generators, post_labels, post_latents,
        intervention, observed, Exp.Synthetic_Sample_Size, gumbel_noise=post_noise,
    )
    cf_rows = map_dictfill_to_discrete(Exp, cf_label_dict, observed)

    reference_cf = get_cf_dist(Exp, observed, intervention, evidence, query["expr"], load_dist=True)
    # The returned pair is not used further; the call is kept as in the original flow.
    _cf_tvd, _cf_kl = match_with_true_dist(Exp, observed, cf_rows, reference_cf, feature_name,
                                           doPrint=False)

    sampled_cf = get_joint_distributions_from_samples(Exp, observed, cf_rows, feature_name)

    # --- purely interventional samples from the generators -------------------
    intv_label_dict = get_generated_labels(
        Exp, label_generators, {}, {}, intervention, observed, Exp.Synthetic_Sample_Size,
    )
    intv_rows = map_dictfill_to_discrete(Exp, intv_label_dict, observed)
    sampled_intv = get_joint_distributions_from_samples(Exp, observed, intv_rows, feature_name)

    print("cf vs intv", calculate_TVD(sampled_cf, sampled_intv, doPrint=False))




def get_conditional_sample_for_images(Exp, label_generators):
    """Save the empirical P(V | X1, X2) of ten randomly drawn generated label rows.

    Labels are generated without intervention, grouped by their first two
    columns (X1, X2), and the group matching the hard-coded ``obs_conf`` is
    turned into a frequency table. Each distinct row gets its relative
    frequency inside the group ("prob") and the absolute difference between
    that frequency and ``true_dist_dict`` ("loss"). Ten rows are then drawn
    with replacement from the group, ordered by decreasing "prob", printed and
    dumped as JSON.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` and ``label_names``
            are read here, the rest is forwarded to project helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None. The result is written to a text file under the labels directory.

    Raises:
        KeyError: if no generated row has ``(X1, X2) == obs_conf`` or a
            generated row is missing from ``true_dist_dict``.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    # (X1, X2) configuration to inspect; (0, 8) and (1, 4) were used previously.
    obs_conf = (0, 1)
    num_samples = 10

    intv_key = {}
    true_bn, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
    _, _, _, true_dist_dict = get_cond_synthetic_dist(
        ["W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"], ["X1", "X2"],
        Exp.label_names, true_bn["feature"])

    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                 Exp.Synthetic_Sample_Size)
    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)

    # Group the generated rows by their (X1, X2) prefix; each row is kept as a (1, n) tensor.
    sample_dict = {}
    for row in generated_labels_full:
        sample_dict.setdefault((row[0], row[1]), []).append(torch.tensor(row).view(1, -1))

    all_samples = sample_dict[obs_conf]
    randices = torch.randint(0, len(all_samples), (num_samples,)).tolist()
    samples = [all_samples[idx] for idx in randices]

    # Frequency table of the selected group, most frequent row first.
    print("key:", obs_conf)
    gen_labels = torch.cat(all_samples, dim=0)
    group_size = gen_labels.shape[0]
    uniques, _, counts = torch.unique(gen_labels, sorted=True, return_inverse=True,
                                      return_counts=True, dim=0)
    _, order = torch.sort(counts, dim=0, descending=True)

    prob_dict = {}
    losses = {}
    for ind in order:
        ky = uniques[ind].view(1, -1)
        cnt = counts[ind]
        fake_prob = cnt / group_size
        print(ky, cnt, "fake prob:", fake_prob)
        row_key = tuple(ky.tolist()[0])
        loss = abs(fake_prob - true_dist_dict[row_key])
        print("loss", loss)
        prob_dict[row_key] = fake_prob.item()
        losses[row_key] = loss.item()

    # Look every drawn sample up directly instead of scanning the whole table per sample.
    chosen = []
    for ss in samples:
        row = ss.tolist()[0]
        row_key = tuple(row)
        print(ss, prob_dict[row_key], losses[row_key])
        chosen.append((row, prob_dict[row_key], losses[row_key]))

    print("---")

    _, order = torch.sort(torch.tensor([prob for _, prob, _ in chosen]), descending=True)
    ranked = [chosen[pos] for pos in order.tolist()]
    l1 = [row for row, _, _ in ranked]
    l2 = [prob for _, prob, _ in ranked]
    l3 = [loss for _, _, loss in ranked]

    for row, prob, loss in ranked:
        print(row, prob, loss)

    save_res = {"obs_comb": l1, "prob": l2, "loss": l3}

    file_name = "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/"+str(obs_conf)+"P(V|X1,X2).txt"
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(save_res))

    return




def get_highest_conditional_sample_for_images(Exp, label_generators):
    """Dump the empirical P(V | X1, X2) table for every generated (X1, X2) group.

    Labels are generated without intervention and grouped by their first two
    columns. For each group, visited in sorted key order, every distinct row is
    listed from most to least frequent together with its relative frequency
    inside the group ("prob") and the absolute gap to ``true_dist_dict``
    ("loss"). The accumulated table is printed and written as JSON.

    Returns ``None``.
    """
    label_names = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    no_intervention = {}

    bayes_net, _ = get_bayesian_network(Exp, no_intervention, load_scm=1)
    _, _, _, true_dist_dict = get_cond_synthetic_dist(
        ["W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"],
        ["X1", "X2"],
        Exp.label_names,
        bayes_net["feature"],
    )

    raw_labels = get_generated_labels(Exp, label_generators, {}, {}, no_intervention, label_names,
                                      Exp.Synthetic_Sample_Size)
    discrete_rows = map_dictfill_to_discrete(Exp, raw_labels, label_names)

    # Bucket rows by their (X1, X2) prefix, each row stored as a (1, n) tensor.
    buckets = {}
    for row in discrete_rows:
        prefix = tuple((row[0], row[1]))
        buckets.setdefault(prefix, []).append(torch.tensor(row).view(1, -1))

    result = {"obs_comb": [], "prob": [], "loss": []}

    for prefix in sorted(buckets):
        print("key:", prefix)
        stacked = torch.cat(buckets[prefix], dim=0)
        bucket_size = stacked.shape[0]
        distinct, _, freq = torch.unique(stacked, sorted=True, return_inverse=True,
                                         return_counts=True, dim=0)

        _, by_frequency = torch.sort(freq, dim=0, descending=True)
        for pos in by_frequency:
            combo = distinct[pos].view(1, -1)
            occurrences = freq[pos]
            fake_prob = occurrences / bucket_size
            print(combo, occurrences, "fake prob:", fake_prob )
            combo_key = tuple(combo.tolist()[0])
            loss = abs(occurrences / bucket_size - true_dist_dict[combo_key])
            print("loss",loss)
            result["obs_comb"].append(combo.tolist())
            result["loss"].append(loss.item())
            result["prob"].append(fake_prob.item())

    for combo, prob, loss in zip(result["obs_comb"], result["prob"], result["loss"]):
        print(combo, prob, loss)

    out_path = "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/P(V|X1,X2).txt"
    with open(out_path, 'w') as handle:
        handle.write(json.dumps(result))

def get_highest_interventional_sample_for_images(Exp, label_generators):
    """Record the most frequent generated label row for every do(X1, X2) setting.

    For each combination produced by ``generate_permutations`` over the
    "feature" dimensions of X1 and X2, labels are generated under that
    intervention and grouped by their first two columns. Within each group the
    single most frequent row is kept, together with the absolute difference
    between its share of *all* generated rows and its ``true_dist_dict`` entry.
    The collected rows and losses are written as JSON.

    Args:
        Exp: experiment object; ``label_dim`` and ``Synthetic_Sample_Size`` are
            read here, the rest is forwarded to project helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    intv_vars = ["X1", "X2"]
    perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in intv_vars])
    key_vals = [dict(zip(intv_vars, comb)) for comb in perms]

    result = {"obs_comb": [], "loss": []}
    for intv_key in key_vals:
        true_bn, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
        _, _, _, true_dist_dict = get_synthetic_dist(Exp, obs_vars, true_bn["feature"])

        generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                     Exp.Synthetic_Sample_Size)
        generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)
        total_rows = generated_labels_full.shape[0]

        # Group the generated rows by their (X1, X2) prefix.
        sample_dict = {}
        for row in generated_labels_full:
            sample_dict.setdefault((row[0], row[1]), []).append(row)

        for key in sorted(sample_dict):
            print("key:", key)
            # Build the tensor once; wrapping an existing tensor in torch.tensor()
            # again only adds a copy and a UserWarning.
            gen_labels = torch.tensor(sample_dict[key])
            uniques, _, counts = torch.unique(gen_labels, sorted=True, return_inverse=True,
                                              return_counts=True, dim=0)

            # Only the most frequent row of the group is reported.
            _, indices = torch.sort(counts, dim=0, descending=True)
            top = indices[0]
            ky = uniques[top]
            cnt = counts[top]
            print(ky, cnt, "fake prob:", cnt / gen_labels.shape[0])
            # The loss is measured against the full sample size, not the group size.
            loss = abs(cnt / total_rows - true_dist_dict[tuple(ky.tolist())])
            print("loss", loss)
            result["obs_comb"].append(ky.tolist())
            result["loss"].append(loss.item())

    file_name = "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/P(V|do(X1,X2)).txt"
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(result))

    return


def get_interventional_sample_for_images(Exp, label_generators):
    """Save the empirical P(V | do(X1, X2)) of ten randomly drawn generated rows.

    Labels are generated under the hard-coded intervention ``intv_key`` and
    grouped by their first two columns; the group equal to ``obs_conf`` is
    turned into a frequency table. "prob" is the relative frequency of a row
    inside the group, "loss" the absolute difference between the row's share of
    *all* generated rows and its ``true_dist_dict`` entry. Ten rows are drawn
    with replacement from the group, ordered by decreasing "prob", printed and
    dumped as JSON.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` is read here, the
            rest is forwarded to project helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None.

    Raises:
        KeyError: if no generated row has ``(X1, X2) == obs_conf`` or a
            generated row is missing from ``true_dist_dict``.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]

    # Intervention and the matching (X1, X2) group; (0, 8) was used previously.
    obs_conf = (1, 4)
    intv_key = {"X1": 1, "X2": 4}
    num_samples = 10

    true_bn, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
    _, _, _, true_dist_dict = get_synthetic_dist(Exp, obs_vars, true_bn["feature"])

    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                 Exp.Synthetic_Sample_Size)
    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)
    total_rows = generated_labels_full.shape[0]

    # Group the generated rows by their (X1, X2) prefix; each row is kept as a (1, n) tensor.
    sample_dict = {}
    for row in generated_labels_full:
        sample_dict.setdefault((row[0], row[1]), []).append(torch.tensor(row).view(1, -1))

    all_samples = sample_dict[obs_conf]
    randices = torch.randint(0, len(all_samples), (num_samples,)).tolist()
    samples = [all_samples[idx] for idx in randices]

    # Frequency table of the selected group, most frequent row first.
    print("key:", obs_conf)
    gen_labels = torch.cat(all_samples, dim=0)
    group_size = gen_labels.shape[0]
    # gen_labels is already a tensor; re-wrapping it in torch.tensor() only
    # copies it and emits a UserWarning.
    uniques, _, counts = torch.unique(gen_labels, sorted=True, return_inverse=True,
                                      return_counts=True, dim=0)
    _, order = torch.sort(counts, dim=0, descending=True)

    prob_dict = {}
    losses = {}
    for ind in order:
        ky = uniques[ind].view(1, -1)
        cnt = counts[ind]
        fake_prob = cnt / group_size
        print(ky, cnt, "fake prob:", fake_prob)
        row_key = tuple(ky.tolist()[0])
        # The loss is measured against the full sample size, not the group size.
        loss = abs(cnt / total_rows - true_dist_dict[row_key])
        print("loss", loss)
        prob_dict[row_key] = fake_prob.item()
        losses[row_key] = loss.item()

    # Look every drawn sample up directly instead of scanning the whole table per sample.
    chosen = []
    for ss in samples:
        row = ss.tolist()[0]
        row_key = tuple(row)
        print(ss, prob_dict[row_key], losses[row_key])
        chosen.append((row, prob_dict[row_key], losses[row_key]))

    print("---")

    _, order = torch.sort(torch.tensor([prob for _, prob, _ in chosen]), descending=True)
    ranked = [chosen[pos] for pos in order.tolist()]
    l1 = [row for row, _, _ in ranked]
    l2 = [prob for _, prob, _ in ranked]
    l3 = [loss for _, _, loss in ranked]

    for row, prob, loss in ranked:
        print(row, prob, loss)

    save_res = {"obs_comb": l1, "prob": l2, "loss": l3}

    file_name = "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/"+str(obs_conf)+"P(V|do(X1,X2)).txt"
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(save_res))

    return


def get_cf_samples_for_image(Exp, label_generators):
    """Save the empirical counterfactual distribution of randomly drawn generated rows.

    A counterfactual query over all (X1, X2) interventions and all (X1p, X2p)
    evidences is built locally, posterior samples are obtained for every
    evidence by rejection sampling, and labels are generated for the hard-coded
    ``evidence`` / ``intv_key`` pair. The rows whose first two columns equal
    ``obs_conf`` are turned into a frequency table; ``num_samples`` rows are
    drawn with replacement, ordered by decreasing relative frequency, printed
    and dumped as JSON. No reference distribution is consulted, so every "loss"
    entry is the placeholder value 100.

    The evidence list is taken from the query built in this function. The
    previous code read a ``cfquery`` name that was never bound here and so only
    worked when a module-level variable of that name happened to exist.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` and ``label_dim`` are
            read here, the rest is forwarded to project helpers.
        label_generators: generators forwarded to the sampling helpers.

    Returns:
        None.

    Raises:
        KeyError: if no generated row has ``(X1, X2) == obs_conf``.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]

    evidence = {"X1p": 1, "X2p": 6}
    intv_key = {"X1": 1, "X2": 5}
    obs_conf = (1, 5)  # (1, 4) was used previously
    num_samples = 20

    n_samples = Exp.Synthetic_Sample_Size

    cf_list = [
        {"intv": ["X1", "X2"], "evid": ["X1p", "X2p"], "expr": "P(V|do(X1,X2),X1p, X2p)"}]
    cf_queries = []
    for cf in cf_list:
        perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in cf["intv"]]).tolist()
        intv_key_val = [dict(zip(cf["intv"], comb)) for comb in perms]

        perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in cf["evid"]]).tolist()
        ev_key_val = [dict(zip(cf["evid"], comb)) for comb in perms]

        cf_queries.append({"obs": obs_vars, "intervs": intv_key_val, "evidence": ev_key_val, "expr": cf["expr"]})

    # Use the query assembled above rather than an unrelated global.
    cfquery = cf_queries[0]
    evidence_list = list(cfquery["evidence"])
    all_posterior_label, all_posterior_latent, all_gumbel_noise = rejection_sampling_optimized(
        Exp, label_generators, n_samples, evidence_list, max_rejections=0, warn=100)

    kev = asKey(evidence)
    posterior_label = all_posterior_label[kev]
    posterior_latent = all_posterior_latent[kev]
    gumbel_noise = all_gumbel_noise[kev]

    cf_all_labels_dict = get_generated_labels(Exp, label_generators, posterior_label, posterior_latent,
                                              intv_key, obs_vars, n_samples, gumbel_noise=gumbel_noise)
    cf_samples = map_dictfill_to_discrete(Exp, cf_all_labels_dict, obs_vars)

    # Group the generated rows by their (X1, X2) prefix; each row is kept as a (1, n) tensor.
    sample_dict = {}
    for row in cf_samples:
        sample_dict.setdefault((row[0], row[1]), []).append(torch.tensor(row).view(1, -1))

    all_samples = sample_dict[obs_conf]
    randices = torch.randint(0, len(all_samples), (num_samples,)).tolist()
    samples = [all_samples[idx] for idx in randices]

    # Frequency table of the selected group, most frequent row first.
    print("key:", obs_conf)
    gen_labels = torch.cat(all_samples, dim=0)
    group_size = gen_labels.shape[0]
    uniques, _, counts = torch.unique(gen_labels, sorted=True, return_inverse=True,
                                      return_counts=True, dim=0)
    _, order = torch.sort(counts, dim=0, descending=True)

    placeholder_loss = 100  # no reference distribution is compared against here
    prob_dict = {}
    for ind in order:
        ky = uniques[ind].view(1, -1)
        cnt = counts[ind]
        fake_prob = cnt / group_size
        print(ky, cnt, "fake prob:", fake_prob)
        prob_dict[tuple(ky.tolist()[0])] = fake_prob.item()

    # Look every drawn sample up directly instead of scanning the whole table per sample.
    chosen = []
    for ss in samples:
        row = ss.tolist()[0]
        prob = prob_dict[tuple(row)]
        print(ss, prob, placeholder_loss)
        chosen.append((row, prob, placeholder_loss))

    print("---")

    _, order = torch.sort(torch.tensor([prob for _, prob, _ in chosen]), descending=True)
    ranked = [chosen[pos] for pos in order.tolist()]
    l1 = [row for row, _, _ in ranked]
    l2 = [prob for _, prob, _ in ranked]
    l3 = [loss for _, _, loss in ranked]

    for row, prob, loss in ranked:
        print(row, prob, loss)

    save_res = {"obs_comb": l1, "prob": l2, "loss": l3}

    ev_str = "".join(str(x) for x in evidence.values())
    in_str = "".join(str(x) for x in intv_key.values())
    file_name = "/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/"+ in_str+"|"+ev_str+"P(Y|do(X1,X2),X1p,X2p).txt"
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(save_res))

    return

def get_highest_cf_samples_for_image(Exp, label_generators):
    """Unfinished counterfactual counterpart of the "highest sample" helpers.

    As written, the function runs rejection sampling for every evidence of the
    first counterfactual query and then writes ``result`` -- which is never
    filled -- to a JSON file.

    NOTE(review): several names used below (``evidence``, ``intv_key``,
    ``observed_var``, ``samples``, ``feature``) are not bound in this function
    or at module level in the visible file, so the body raises ``NameError``
    once execution gets past the early return. ``cur_mechs``, ``tvd_diff`` and
    ``kl_diff`` are read from module scope.
    """

    # Never appended to below; dumped unchanged at the end.
    result = {"obs_comb": [], "prob":[], "loss": []}


    obs_vars=["X1", "X2","W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]

    feat = "feature"
    cfquery = Exp.cf_queries[0]

    # Skip when none of the query's observed variables is among the module-level cur_mechs.
    # NOTE(review): this branch returns a (tvd_diff, kl_diff) tuple taken from
    # module scope, while the normal path returns None.
    if bool(set(cfquery["obs"]) & set(cur_mechs)) == False:
        return tvd_diff, kl_diff

    # Map the evidence variable names of the first evidence through Exp.twin_map.
    evidence_vars = [Exp.twin_map[lb] for lb in cfquery["evidence"][0].keys()]
    compare_Var = list(evidence_vars)  # getting the intervened variables
    query_str = getdoKey(compare_Var, dict({}))  # getting the scm saving file name
    obs_dist = get_intv_dist(Exp, compare_Var, dict({}), query_str)  # getting the obs distribution of intv variables

    # Not used afterwards.
    final_tvd = 0
    final_kl = 0

    n_samples = Exp.Synthetic_Sample_Size

    # Posterior samples for every evidence listed in the query.
    evidence_list = [evidence for evidence in cfquery["evidence"]]
    all_posterior_label, all_posterior_latent, all_gumbel_noise = rejection_sampling_optimized(Exp, label_generators,
                                                                                               n_samples, evidence_list,
                                                                                               max_rejections=0,
                                                                                               warn=100)


    # NOTE(review): evidence_list is rebound after sampling and neither it nor
    # intv_dict is read again -- presumably a loop over these pairs was planned.
    evidence_list=[{"X1p":0, "X2p":0}, {"X1p":0, "X2p":1}, {"X1p":0, "X2p":2},
                   {"X1p":1, "X2p":6}, {"X1p":1, "X2p":7}, {"X1p":1, "X2p":8}]

    intv_dict=[{0:[3,5]}, {1:[3,5]}, {2:[4,7]},
               {3:[1,3]}, {4:[1,3]}, {5:[6,7]}]



    # NOTE(review): `evidence` is not defined at this point (the comprehension
    # variable above does not leak into the function scope).
    kev = asKey(evidence)
    posterior_label, posterior_latent, gumbel_noise = all_posterior_label[kev], all_posterior_latent[kev], all_gumbel_noise[kev]

    # NOTE(review): `intv_key` is not defined in this function.
    cf_all_labels_dict = get_generated_labels(Exp, label_generators, posterior_label, posterior_latent,
                                              intv_key, cfquery["obs"], n_samples, gumbel_noise=gumbel_noise)
    cf_samples = map_dictfill_to_discrete(Exp, cf_all_labels_dict, cfquery["obs"])

    # NOTE(review): `observed_var`, `samples` and `feature` are not defined here;
    # possibly cfquery["obs"], cf_samples and feat were meant -- confirm.
    upd_dist = get_joint_distributions_from_samples(Exp, observed_var, samples, feature)

    # true_cf_dist = get_cf_dist(Exp, cfquery["obs"], intv_key, evidence, cfquery["expr"], load_dist=True)

    print(f"CF query done for evidence:{evidence}, intv_key: {intv_key} ")


    file_name ="/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/P(Y|do(X1,X2),X1',X2').txt"
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(result))

    return


# ---------------------------------------------------------------------------
# Script section: build the experiment, load the most recent saved models,
# plot a few generated images and print the counterfactual TVD losses.
#
# NOTE(review): Experiment, set_DGP, get_generators, plot_trained_digits and
# get_expected_loss_countefactuals are not imported in the visible part of
# this file -- confirm where they come from.
# ---------------------------------------------------------------------------
Exp = Experiment("Exp1", set_DGP,
                 dist_thresh=0.15,
                 causal_hierarchy=2,
                 Temperature=1,
                 temp_min=0.1,
                 G_hid_dims=[256, 256],
                 D_hid_dims=[256, 256],
                 # IMAGE_FILTERS=[512, 256, 128],
                 IMAGE_FILTERS=[128, 64, 32],
                 CRITIC_ITERATIONS=5,
                 LAMBDA_GP=3,
                 learning_rate=2 * 1e-4,
                 Synthetic_Sample_Size=40000,
                 intv_Sample_Size=40000,
                 batch_size=200,
                 features=["feature"],
                 noise_states=100,
                 latent_state=16,
                 Data_intervs=[{}],
                 # Data_observs=[
                 #     {"cond": ["D"], "obs": ["I"], "mech":["I"]},
                 #     {"cond": ["I"], "obs": ["D", "C"], "mech":["D", "C"]}
                 # ],
                 num_epochs=300,
                 new_experiment=False
                 )

# get_expected_true_cf(Exp)
# Override the sample size passed to the constructor and reuse the training batch size.
Exp.Synthetic_Sample_Size = 20000
Exp.intv_batch_size = Exp.batch_size

# compare_Var = ["Ycolor"]  # getting the intervened variables
# query_str = getdoKey(compare_Var, dict({"X1":0, "X2":1}))  # getting the scm saving file name
# obs_dist = get_intv_dist(Exp, compare_Var, dict({}), query_str)  # getting the obs distribution of intv variables
# print("P(Ycolor|do(X1,X2)",obs_dist)

# ret = get_expected_true_intervs(Exp)
#
# for x2 in range(9):
#     true_bn, _ = get_bayesian_network(Exp, {"X2":x2}, load_scm=1)
#     _, _, _, true_dist_dict = get_synthetic_dist(Exp, ["Ycolor"], true_bn["feature"])
#     print(f" X2: {x2}, {true_dist_dict}")
#
#
# bn_dict, INSTANCES = get_bayesian_network(Exp, {}, load_scm=1)
# ret = get_cond_synthetic_dist(["Ycolor"], ["X2"], ["Ycolor","X2"], bn_dict["feature"])
# print(ret)
#
# Exp.true_bn = {}



# Taking the expectation is not a good measure here.
# intv = get_expected_true_intervs(Exp)
# cf= get_expected_true_cf(Exp)
# true_intvVscf = calculate_TVD(intv, cf, doPrint=False)


# SHARED_INFO.txt is parsed as JSON; its "last_exp" entry is used as the model load path.
SHARED_INFO = "/path_to_project/SAVED_EXPERIMENTS/"+Exp.Complete_DAG_desc+"/SHARED_INFO.txt"
with open(SHARED_INFO) as f:
    data = f.read()
INSTANCE = json.loads(data)

last_exp = INSTANCE["last_exp"]
# last_exp ="/path_to_project/SAVED_EXPERIMENTS/mnist_nonId_newgraph/Exp1/Sep_19_2022-14_15"
# last_exp ="/path_to_project/SAVED_EXPERIMENTS/mnist_nonId_newgraph/Exp1/Sep_20_2022-17_15"
print(last_exp)
Exp.LOAD_MODEL_PATH = last_exp

# load_which_models = {"X1": True, "X2": True, "W": True, "Ydigit1": True, "Ydigit2": True, "Ycolor": True,
#                          "Ythick": True}

# load_which_models = {"X1": False, "X2": False, "W": False, "Ydigit1": False, "Ydigit2": False, "Ycolor": False,
#                                                   "Ythick": False}

# load_which_models = {"D": True, "I": True, "C": True}
Exp.load_which_models = {"D": True, "I": True, "RI": True, "C": True}

# cur_mechs = ["Ydigit1", "Ydigit2", "Ythick"]
# cur_mechs = ["X1", "X2", "W", "Ycolor"]

# Module-level names; cur_mechs is also read inside get_highest_cf_samples_for_image.
cur_mechs=["D", "I", "C"]
compare_Var=[ "C"]

# label_generators is also read as a global by check_query.
label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)

# Put every generator in evaluation mode before sampling.
for gen in label_generators:
    label_generators[gen].eval()

with torch.no_grad():
    # --------
    # Generate 100 samples under do(D=0) and plot each generated image.
    intv_pro = {"D":0}
    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_pro, cur_mechs, 100)
    # generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_pro, compare_Var,
    #                                              Exp.Synthetic_Sample_Size)

    # Separate the image entry from the label entries before discretising the labels.
    generated_image = generated_labels_dict[Exp.image_labels[0]]
    del generated_labels_dict[Exp.image_labels[0]]

    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)

    dataset_dist_dict = get_joint_distributions_from_samples(Exp, compare_Var,
                                                             generated_labels_full, "feature")

    print(dataset_dist_dict)
    # for genimg in generated_image:
    # NOTE(review): compare_Var holds a single name, yet grow[1] is read below --
    # confirm the column layout returned by map_dictfill_to_discrete.
    for grow, genimg in zip(generated_labels_full, generated_image):
        # Reorder axes (1, 2, 0) before plotting -- presumably channels-first to channels-last.
        genimg = genimg.permute(1, 2, 0).detach().cpu().numpy()
        plot_trained_digits(1, 1, [genimg], [f'Digit:{grow[0]} Color:{["red", "green", "blue", "white"][grow[1]]}'])



    # ----------




    # These two names are not read again in the active code of this section.
    cfquery = Exp.cf_queries[0]
    evidence_list= [evidence for evidence in cfquery["evidence"]]

    # n_samples = Exp.Synthetic_Sample_Size
    # posterior_label, posterior_latent, gumbel_noise = rejection_sampling_optimized(Exp, label_generators, n_samples, evidence_list,
    #                                                                          max_rejections=0, warn=100)






    # get_conditional_sample_for_images(Exp, label_generators)
    # get_interventional_sample_for_images(Exp, label_generators)
    # get_cf_samples_for_image(Exp, label_generators)


    ######
    # check_query(Exp)


    #####


    # TODO: intervened images

    # Result containers: one (initially empty) list per query expression.
    feat = "feature"
    tvd_diff = {}
    kl_diff = {}
    obs_vars= ["X1", "X2","W", "Ydigit1", "Ydigit2"]

    obs_query = getdoKey(obs_vars, [])
    tvd_diff[obs_query] = []
    kl_diff[obs_query] = []



    # Same registration again, this time with an empty dict as the intervention argument.
    query_string= getdoKey(obs_vars, {})
    tvd_diff[query_string]=[]
    kl_diff[query_string]=[]

    for query in Exp.interv_queries:
        tvd_diff[query["expr"]] = []
        kl_diff[query["expr"]] = []

    for query in Exp.cf_queries:
        tvd_diff[query["expr"]] = []
        kl_diff[query["expr"]] = []



    # tvd_diff, kl_diff, true_intv, fake_intv  = get_expected_loss_interventions(Exp, cur_mechs, label_generators, tvd_diff, kl_diff)
    # tvd_diff, kl_diff, true_cf, fake_cf = get_expected_loss_countefactuals(Exp, cur_mechs, label_generators, tvd_diff, kl_diff)
    # tvd_intv= calculate_TVD(true_intv, fake_intv, doPrint=False)
    # tvd_cf= calculate_TVD(true_cf, fake_cf, doPrint=False)
    # # Very small differences. Does that mean the confounders are not working properly?
    # true_intvVscf= calculate_TVD(true_intv, true_cf, doPrint=False)
    # fake_intvVscf= calculate_TVD(fake_intv, fake_cf, doPrint=False)



    # tvd_diff, kl_diff = get_expected_loss_interventions(Exp, cur_mechs, label_generators, tvd_diff, kl_diff)

    # Only the counterfactual loss is evaluated in the current configuration.
    tvd_diff, kl_diff, td,fd = get_expected_loss_countefactuals(Exp, cur_mechs, label_generators, tvd_diff, kl_diff)
    # tvd_diff, kl_diff, td, fd = get_observational_loss(Exp, Exp.label_names, label_generators, tvd_diff, kl_diff)
    # tvd_diff, kl_diff, td, fd = get_observational_loss(Exp, obs_vars, label_generators, tvd_diff, kl_diff)
    # tvd_diff, kl_diff = get_observational_loss(Exp, obs_vars, label_generators, tvd_diff, kl_diff)

    # Print the collected TVD values per query expression.
    for dist in tvd_diff:
        print(dist, tvd_diff[dist])

# cfquery = Exp.cf_queries[0]
#
# for evidence in cfquery["evidence"]:
#     for intv_key in cfquery["intervs"]:
#         true_cf_dist = get_cf_dist(Exp, cfquery["obs"], intv_key, evidence, cfquery["expr"], load_dist=True)
#         print(f' intv:{intv_key}, evidence:{evidence}, dist:{true_cf_dist}')
#













